[Wav2Vec2 Conformer] Fix inference float16 #25985
sanchit-gandhi merged 4 commits into huggingface:main
Conversation
@slow
@require_torch_gpu
def test_wav2vec2_conformer_float16(self):
This is the error repro that was failing before @Vaibhavs10 - added a slow integration test to make sure this works after the fix
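For reference, a slow float16 integration test along these lines might look roughly as follows. This is a sketch only, not the exact test added in the PR; the rotary-embedding checkpoint name and the dummy LibriSpeech dataset are assumptions for illustration.

```python
# Sketch of a float16 inference test for the rotary-embedding conformer checkpoint
# (illustrative only; checkpoint and dataset names are assumptions).
import torch
from datasets import load_dataset
from transformers import Wav2Vec2ConformerForCTC, Wav2Vec2Processor

model_id = "facebook/wav2vec2-conformer-rope-large-960h-ft"  # assumed rotary-embedding checkpoint
model = Wav2Vec2ConformerForCTC.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(model_id)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
input_values = inputs.input_values.to("cuda", dtype=torch.float16)

with torch.no_grad():
    logits = model(input_values).logits  # previously raised a dtype mismatch in float16

transcription = processor.batch_decode(torch.argmax(logits, dim=-1))
print(transcription)
```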
return self.cached_rotary_positional_embedding

self.cached_sequence_length = sequence_length
# Embeddings are computed in the dtype of the inv_freq constant
This now looks a lot like:
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

Wondering if we can add copied from and use this / wondering if the dynamic scaling could also work for audio models?
Can't use # Copied from on the whole module since the Wav2Vec2ConformerRotaryPositionalEmbedding accepts the config as an argument, but LlamaRotaryEmbedding uses various ad-hoc arguments. But we could do a similar dynamic slicing - will add this in a follow-up PR so as not to block @Vaibhavs10
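For illustration, Llama-style dynamic caching adapted to the conformer rotary embedding could look roughly like the sketch below. This is not the actual follow-up implementation; the config attribute names (hidden_size, num_attention_heads, rotary_embedding_base) are assumed from Wav2Vec2ConformerConfig.

```python
# Sketch of a Llama-style cache for the conformer rotary embedding (illustrative only,
# not the merged implementation or the planned follow-up).
import torch
from torch import nn


class CachedRotaryPositionalEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size // config.num_attention_heads
        base = config.rotary_embedding_base  # assumed attribute name
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.cached_sequence_length = 0
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states):
        sequence_length = hidden_states.shape[1]

        # Recompute the cache only when the sequence grows beyond what was cached
        if sequence_length > self.cached_sequence_length or self.cached_rotary_positional_embedding is None:
            self.cached_sequence_length = sequence_length
            # Compute in float32 (the dtype of inv_freq) for numerical stability
            time_stamps = torch.arange(sequence_length, device=hidden_states.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
            embeddings = torch.cat((freqs, freqs), dim=-1)
            self.cached_rotary_positional_embedding = torch.stack([embeddings.cos(), embeddings.sin()])

        # Slice to the current length and cast to the dtype of the hidden states (e.g. float16)
        return self.cached_rotary_positional_embedding[:, :sequence_length].to(hidden_states.dtype)
```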
The documentation is not available anymore as the PR was closed or merged.
ylacombe
left a comment
LGTM! Thanks for taking care of this!
ArthurZucker
left a comment
Looks good to me! Left a nit, but I think we can use the LlamaRotary class now 😄
* [Wav2Vec2 Conformer] Fix inference float16
* fix test
* fix test more
* clean pipe test
What does this PR do?
Fixes #25964 - the Wav2Vec2 Conformer model with rotary embeddings now works when loaded with from_pretrained in float16. The issue originated in the rotary embedding layer, which always returned the positional embeddings in float32.
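In essence (a paraphrase of the change, not the exact diff), the fix casts the rotary embedding to the dtype of the incoming hidden states before it is returned. The helper below is hypothetical and only illustrates the dtype handling:

```python
import torch


def rotary_embeddings_like(hidden_states: torch.Tensor, inv_freq: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cos/sin rotary embeddings matching the dtype of `hidden_states`."""
    sequence_length = hidden_states.shape[1]
    time_stamps = torch.arange(sequence_length, device=hidden_states.device).type_as(inv_freq)
    freqs = torch.einsum("i,j->ij", time_stamps, inv_freq)
    embeddings = torch.cat((freqs, freqs), dim=-1)
    cos_embeddings = embeddings.cos()[:, None]
    sin_embeddings = embeddings.sin()[:, None]
    # The key change: cast to the dtype of the hidden states (e.g. float16)
    # instead of always returning float32
    return torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
```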